{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([0.4324468 , 0.48173807, 0.98724231, 0.51548606, 0.71232303,\n",
       "        0.97799668, 0.60370915, 0.32634855, 0.38733207, 0.08664855]),\n",
       " [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1]])"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "#每个老虎机的中奖概率,0-1之间的均匀分布\n",
    "probs = np.random.uniform(size=10)\n",
    "\n",
    "#记录每个老虎机的返回值\n",
    "rewards = [[1] for _ in range(10)]\n",
    "\n",
    "probs, rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "#随机选择的概率递减的贪婪算法\n",
    "def choose_one():\n",
    "    #求出每个老虎机各玩了多少次\n",
    "    played_count = [len(i) for i in rewards]\n",
    "    played_count = np.array(played_count)\n",
    "\n",
    "    #求出上置信界\n",
    "    #分子是总共玩了多少次,取根号后让他的增长速度变慢\n",
    "    #分母是每台老虎机玩的次数,乘以2让他的增长速度变快\n",
    "    #随着玩的次数增加,分母会很快超过分子的增长速度,导致分数越来越小\n",
    "    #具体到每一台老虎机,则是玩的次数越多,分数就越小,也就是ucb的加权越小\n",
    "    #所以ucb衡量了每一台老虎机的不确定性,不确定性越大,探索的价值越大\n",
    "    fenzi = played_count.sum()**0.5\n",
    "    fenmu = played_count * 2\n",
    "    ucb = fenzi / fenmu\n",
    "\n",
    "    #ucb本身取根号\n",
    "    #大于1的数会被缩小,小于1的数会被放大,这样保持ucb恒定在一定的数值范围内\n",
    "    ucb = ucb**0.5\n",
    "\n",
    "    #计算每个老虎机的奖励平均\n",
    "    rewards_mean = [np.mean(i) for i in rewards]\n",
    "    rewards_mean = np.array(rewards_mean)\n",
    "\n",
    "    #ucb和期望求和\n",
    "    ucb += rewards_mean\n",
    "\n",
    "    return ucb.argmax()\n",
    "\n",
    "\n",
    "choose_one()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[1, 1], [1], [1], [1], [1], [1], [1], [1], [1], [1]]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def try_and_play():\n",
    "    i = choose_one()\n",
    "\n",
    "    #玩老虎机,得到结果\n",
    "    reward = 0\n",
    "    if random.random() < probs[i]:\n",
    "        reward = 1\n",
    "\n",
    "    #记录玩的结果\n",
    "    rewards[i].append(reward)\n",
    "\n",
    "\n",
    "try_and_play()\n",
    "\n",
    "rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 312
    },
    "executionInfo": {
     "elapsed": 676,
     "status": "ok",
     "timestamp": 1649954384006,
     "user": {
      "displayName": "Sam Lu",
      "userId": "15789059763790170725"
     },
     "user_tz": -480
    },
    "id": "wIHh_wRA8YDz",
    "outputId": "d5d65ff2-744d-44e2-ec8a-eb78d13397c2",
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4936.211534652689, 4553)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_result():\n",
    "    #玩N次\n",
    "    for _ in range(5000):\n",
    "        try_and_play()\n",
    "\n",
    "    #期望的最好结果\n",
    "    target = probs.max() * 5000\n",
    "\n",
    "    #实际玩出的结果\n",
    "    result = sum([sum(i) for i in rewards])\n",
    "\n",
    "    return target, result\n",
    "\n",
    "\n",
    "get_result()"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第2章-多臂老虎机问题.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
